K-Nearest Neighbors (KNN) is a supervised machine learning algorithm that predicts which class a data point belongs to. The training data is labeled; to classify a new data point, the algorithm finds the K training points nearest to it, and the class that occurs most often among those K neighbors is taken as the predicted class. We use labeled pumpkin seed data to build a model and predict the class of the held-out testing data. We use the class package in R to access the KNN algorithm. We then scale the data, split it into training and testing sets, and evaluate the accuracy of the model. Scatterplots are used for visualization.

Original Dataset: https://www.kaggle.com/datasets/muratkokludataset/pumpkin-seeds-dataset

Install Packages

install.packages('tidyverse', repos = "http://cran.us.r-project.org")
## Installing package into 'C:/Users/Steve/AppData/Local/R/win-library/4.2'
## (as 'lib' is unspecified)
## package 'tidyverse' successfully unpacked and MD5 sums checked
## 
## The downloaded binary packages are in
##  C:\Users\Steve\AppData\Local\Temp\RtmpQTTRME\downloaded_packages
install.packages('ggplot2', repos = "http://cran.us.r-project.org")
## Installing package into 'C:/Users/Steve/AppData/Local/R/win-library/4.2'
## (as 'lib' is unspecified)
## package 'ggplot2' successfully unpacked and MD5 sums checked
## 
## The downloaded binary packages are in
##  C:\Users\Steve\AppData\Local\Temp\RtmpQTTRME\downloaded_packages
install.packages('class', repos = "http://cran.us.r-project.org")
## Installing package into 'C:/Users/Steve/AppData/Local/R/win-library/4.2'
## (as 'lib' is unspecified)
## package 'class' successfully unpacked and MD5 sums checked
## Warning: cannot remove prior installation of package 'class'
## Warning in file.copy(savedcopy, lib, recursive = TRUE): problem copying C:
## \Users\Steve\AppData\Local\R\win-library\4.2\00LOCK\class\libs\x64\class.dll
## to C:\Users\Steve\AppData\Local\R\win-library\4.2\class\libs\x64\class.dll:
## Permission denied
## Warning: restored 'class'
## 
## The downloaded binary packages are in
##  C:\Users\Steve\AppData\Local\Temp\RtmpQTTRME\downloaded_packages
install.packages('readxl', repos = "http://cran.us.r-project.org")
## Installing package into 'C:/Users/Steve/AppData/Local/R/win-library/4.2'
## (as 'lib' is unspecified)
## package 'readxl' successfully unpacked and MD5 sums checked
## Warning: cannot remove prior installation of package 'readxl'
## Warning in file.copy(savedcopy, lib, recursive = TRUE): problem copying C:
## \Users\Steve\AppData\Local\R\win-library\4.2\00LOCK\readxl\libs\x64\readxl.dll
## to C:\Users\Steve\AppData\Local\R\win-library\4.2\readxl\libs\x64\readxl.dll:
## Permission denied
## Warning: restored 'readxl'
## 
## The downloaded binary packages are in
##  C:\Users\Steve\AppData\Local\Temp\RtmpQTTRME\downloaded_packages
install.packages("caret", repos = "http://cran.us.r-project.org")
## Installing package into 'C:/Users/Steve/AppData/Local/R/win-library/4.2'
## (as 'lib' is unspecified)
## package 'caret' successfully unpacked and MD5 sums checked
## Warning: cannot remove prior installation of package 'caret'
## Warning in file.copy(savedcopy, lib, recursive = TRUE): problem copying C:
## \Users\Steve\AppData\Local\R\win-library\4.2\00LOCK\caret\libs\x64\caret.dll
## to C:\Users\Steve\AppData\Local\R\win-library\4.2\caret\libs\x64\caret.dll:
## Permission denied
## Warning: restored 'caret'
## 
## The downloaded binary packages are in
##  C:\Users\Steve\AppData\Local\Temp\RtmpQTTRME\downloaded_packages
library(tidyverse)
## ── Attaching packages
## ───────────────────────────────────────
## tidyverse 1.3.2 ──
## ✔ ggplot2 3.4.0      ✔ purrr   0.3.5 
## ✔ tibble  3.1.8      ✔ dplyr   1.0.10
## ✔ tidyr   1.2.1      ✔ stringr 1.4.1 
## ✔ readr   2.1.3      ✔ forcats 0.5.2 
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag()    masks stats::lag()
library(ggplot2)
library(class)
library(readxl)
library(caret)
## Loading required package: lattice
## 
## Attaching package: 'caret'
## 
## The following object is masked from 'package:purrr':
## 
##     lift

Read xlsx

# Load the pumpkin seed measurements from the Excel workbook into a tibble.
pumpkin_data <- read_xlsx(path = "Pumpkin_Seeds_Dataset.xlsx")

Inspect the data with str(), summary(), and head()

# Show the type and the first few values of every column.
str(pumpkin_data)
## tibble [2,500 × 13] (S3: tbl_df/tbl/data.frame)
##  $ Area             : num [1:2500] 56276 76631 71623 66458 66107 ...
##  $ Perimeter        : num [1:2500] 888 1068 1083 992 998 ...
##  $ Major_Axis_Length: num [1:2500] 326 417 436 382 384 ...
##  $ Minor_Axis_Length: num [1:2500] 220 234 211 223 220 ...
##  $ Convex_Area      : num [1:2500] 56831 77280 72663 67118 67117 ...
##  $ Equiv_Diameter   : num [1:2500] 268 312 302 291 290 ...
##  $ Eccentricity     : num [1:2500] 0.738 0.828 0.875 0.812 0.819 ...
##  $ Solidity         : num [1:2500] 0.99 0.992 0.986 0.99 0.985 ...
##  $ Extent           : num [1:2500] 0.745 0.715 0.74 0.74 0.675 ...
##  $ Roundness        : num [1:2500] 0.896 0.844 0.767 0.849 0.834 ...
##  $ Aspect_Ration    : num [1:2500] 1.48 1.78 2.07 1.71 1.74 ...
##  $ Compactness      : num [1:2500] 0.821 0.749 0.693 0.762 0.756 ...
##  $ Class            : chr [1:2500] "Çerçevelik" "Çerçevelik" "Çerçevelik" "Çerçevelik" ...
# Print per-column summary statistics (quartiles and mean for numeric columns).
summary(pumpkin_data)
##       Area          Perimeter      Major_Axis_Length Minor_Axis_Length
##  Min.   : 47939   Min.   : 868.5   Min.   :320.8     Min.   :152.2    
##  1st Qu.: 70765   1st Qu.:1048.8   1st Qu.:415.0     1st Qu.:211.2    
##  Median : 79076   Median :1123.7   Median :449.5     Median :224.7    
##  Mean   : 80658   Mean   :1130.3   Mean   :456.6     Mean   :225.8    
##  3rd Qu.: 89758   3rd Qu.:1203.3   3rd Qu.:492.7     3rd Qu.:240.7    
##  Max.   :136574   Max.   :1559.5   Max.   :661.9     Max.   :305.8    
##   Convex_Area     Equiv_Diameter   Eccentricity       Solidity     
##  Min.   : 48366   Min.   :247.1   Min.   :0.4921   Min.   :0.9186  
##  1st Qu.: 71512   1st Qu.:300.2   1st Qu.:0.8317   1st Qu.:0.9883  
##  Median : 79872   Median :317.3   Median :0.8637   Median :0.9903  
##  Mean   : 81508   Mean   :319.3   Mean   :0.8609   Mean   :0.9895  
##  3rd Qu.: 90798   3rd Qu.:338.1   3rd Qu.:0.8970   3rd Qu.:0.9915  
##  Max.   :138384   Max.   :417.0   Max.   :0.9481   Max.   :0.9944  
##      Extent         Roundness      Aspect_Ration    Compactness    
##  Min.   :0.4680   Min.   :0.5546   Min.   :1.149   Min.   :0.5608  
##  1st Qu.:0.6589   1st Qu.:0.7519   1st Qu.:1.801   1st Qu.:0.6635  
##  Median :0.7130   Median :0.7977   Median :1.984   Median :0.7077  
##  Mean   :0.6932   Mean   :0.7915   Mean   :2.042   Mean   :0.7041  
##  3rd Qu.:0.7402   3rd Qu.:0.8343   3rd Qu.:2.262   3rd Qu.:0.7435  
##  Max.   :0.8296   Max.   :0.9396   Max.   :3.144   Max.   :0.9049  
##     Class          
##  Length:2500       
##  Class :character  
##  Mode  :character  
##                    
##                    
## 
# Preview the first six rows.
head(pumpkin_data)
## # A tibble: 6 × 13
##    Area Perimeter Major…¹ Minor…² Conve…³ Equiv…⁴ Eccen…⁵ Solid…⁶ Extent Round…⁷
##   <dbl>     <dbl>   <dbl>   <dbl>   <dbl>   <dbl>   <dbl>   <dbl>  <dbl>   <dbl>
## 1 56276      888.    326.    220.   56831    268.   0.738   0.990  0.745   0.896
## 2 76631     1068.    417.    234.   77280    312.   0.828   0.992  0.715   0.844
## 3 71623     1083.    436.    211.   72663    302.   0.875   0.986  0.74    0.767
## 4 66458      992.    382.    223.   67118    291.   0.812   0.990  0.740   0.849
## 5 66107      998.    384.    220.   67117    290.   0.819   0.985  0.675   0.834
## 6 73191     1041.    406.    231.   73969    305.   0.822   0.990  0.716   0.848
## # … with 3 more variables: Aspect_Ration <dbl>, Compactness <dbl>, Class <chr>,
## #   and abbreviated variable names ¹​Major_Axis_Length, ²​Minor_Axis_Length,
## #   ³​Convex_Area, ⁴​Equiv_Diameter, ⁵​Eccentricity, ⁶​Solidity, ⁷​Roundness

Create a scatterplot of the pumpkin seeds

# Plot area against aspect ratio; each seed class gets its own point shape
# and colour. The title is centred via hjust = 0.5.
ggplot(pumpkin_data, aes(x = Area, y = Aspect_Ration)) +
  geom_point(aes(shape = Class, colour = Class)) +
  labs(
    title = "Scatter Plot of Pumpkin Seeds",
    x = "Area",
    y = "Aspect Ratio"
  ) +
  theme(plot.title = element_text(hjust = 0.5))

Encode Class as a numeric label (Çerçevelik = 1, any other class = 2)

# Add a numeric label column: 1 for "Çerçevelik", 2 for any other class.
pumpkin_data <- pumpkin_data %>%
  mutate(Species = if_else(Class == "Çerçevelik", 1, 2))
# Confirm that the new Species column was appended.
str(pumpkin_data)
## tibble [2,500 × 14] (S3: tbl_df/tbl/data.frame)
##  $ Area             : num [1:2500] 56276 76631 71623 66458 66107 ...
##  $ Perimeter        : num [1:2500] 888 1068 1083 992 998 ...
##  $ Major_Axis_Length: num [1:2500] 326 417 436 382 384 ...
##  $ Minor_Axis_Length: num [1:2500] 220 234 211 223 220 ...
##  $ Convex_Area      : num [1:2500] 56831 77280 72663 67118 67117 ...
##  $ Equiv_Diameter   : num [1:2500] 268 312 302 291 290 ...
##  $ Eccentricity     : num [1:2500] 0.738 0.828 0.875 0.812 0.819 ...
##  $ Solidity         : num [1:2500] 0.99 0.992 0.986 0.99 0.985 ...
##  $ Extent           : num [1:2500] 0.745 0.715 0.74 0.74 0.675 ...
##  $ Roundness        : num [1:2500] 0.896 0.844 0.767 0.849 0.834 ...
##  $ Aspect_Ration    : num [1:2500] 1.48 1.78 2.07 1.71 1.74 ...
##  $ Compactness      : num [1:2500] 0.821 0.749 0.693 0.762 0.756 ...
##  $ Class            : chr [1:2500] "Çerçevelik" "Çerçevelik" "Çerçevelik" "Çerçevelik" ...
##  $ Species          : num [1:2500] 1 1 1 1 1 1 1 1 1 1 ...
# Drop the text Class column (column 13); the 12 numeric features and the
# numeric Species label remain.
Pumpkin_data1 <- pumpkin_data %>%
  select(-Class)

Scaling the data (subtract mean and divide by standard deviation)

# Standardise the 12 feature columns to zero mean and unit standard deviation;
# the Species label in column 13 is left untouched.
# NOTE(review): the means and standard deviations are computed on the full
# data set before the train/test split, so testing rows influence the scaling.
# Consider fitting the scaling on the training rows only.
Pumpkin_data1 <- Pumpkin_data1 %>%
  mutate(across(1:12, ~ as.numeric(scale(.x))))

Set Seed for sampling and code reproduction

# Fix the random number generator seed so that the random draws made after
# this point are reproducible.
set.seed(1234)

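Use the sample function for an approximate 80/20 split (each row is assigned to the training set with probability 0.8 and to the testing set with probability 0.2)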

# Draw a 1 (probability 0.8) or a 2 (probability 0.2) independently for every
# row, so the resulting group sizes are only approximately 80/20.
index <- sample(
  x = 1:2,
  size = nrow(Pumpkin_data1),
  replace = TRUE,
  prob = c(0.8, 0.2)
)

Split into Training and Testing Data

# Rows drawn as 1 form the training set and rows drawn as 2 the testing set.
# Species (column 13) is separated out as the label; the feature tables keep
# only the 12 measurement columns.
training_label <- Pumpkin_data1 %>%
  filter(index == 1) %>%
  select(Species)
testing_label <- Pumpkin_data1 %>%
  filter(index == 2) %>%
  select(Species)
training_data <- Pumpkin_data1 %>%
  filter(index == 1) %>%
  select(-Species)
testing_data <- Pumpkin_data1 %>%
  filter(index == 2) %>%
  select(-Species)

Use the square root of the number of training rows, rounded to the nearest integer, as K; we get K = 45.

Using an odd K is better for the algorithm: with two classes it avoids tied votes.

# Rule of thumb: K is the rounded square root of the number of training rows.
K <- round(sqrt(nrow(training_data)))
print(K)
## [1] 45

Use KNN algorithm

# Classify every testing row by majority vote among its K nearest training
# rows. K holds the rounded square root of the training-set size computed
# above; the labels are passed as the plain Species vector.
predictions <- knn(
  train = training_data,
  test = testing_data,
  cl = training_label$Species,
  k = K
)

Calculating Accuracy of predictions using confusion matrix

We obtain an accuracy of ~90%.

# Cross-tabulate the KNN predictions against the true labels.
# caret::confusionMatrix() reads a table with predictions in the ROWS and the
# reference (true) labels in the COLUMNS. With the axes the other way round,
# Sensitivity/Specificity, Pos/Neg Pred Value, Prevalence/Detection Prevalence
# and the No Information Rate are reported for the wrong axis. Accuracy, Kappa
# and McNemar's test are symmetric and do not change.
Accuracy <- table(
  knn_prediction = predictions,
  testing_labels = testing_label$Species
)
confusionMatrix(Accuracy)
## Confusion Matrix and Statistics
## 
##               testing_labels
## knn_prediction   1   2
##              1 248  35
##              2  15 195
##                                           
##                Accuracy : 0.8986          
##                  95% CI : (0.8685, 0.9238)
##     No Information Rate : 0.5335          
##     P-Value [Acc > NIR] : < 2e-16         
##                                           
##                   Kappa : 0.7951          
##                                           
##  Mcnemar's Test P-Value : 0.00721         
##                                           
##             Sensitivity : 0.9430          
##             Specificity : 0.8478          
##          Pos Pred Value : 0.8763          
##          Neg Pred Value : 0.9286          
##              Prevalence : 0.5335          
##          Detection Rate : 0.5030          
##    Detection Prevalence : 0.5740          
##       Balanced Accuracy : 0.8954          
##                                           
##        'Positive' Class : 1               
## 

Create dataframe of Training data

# Re-attach the training labels to the training features as a factor column
# named Class.
df_training <- training_data %>%
  mutate(Class = as.factor(training_label$Species))

Create a scatterplot of Training Data

# Plot the (scaled) training features; point shape encodes the class label.
ggplot(df_training, aes(x = Area, y = Aspect_Ration)) +
  geom_point(aes(shape = Class)) +
  labs(
    title = "Scatter Plot of Pumpkin Seed Training Data",
    x = "Area",
    y = "Aspect Ratio"
  ) +
  theme(plot.title = element_text(hjust = 0.5))

We create a combined data frame for Testing Data

# Combine the testing features with the true label, the KNN prediction and a
# Correct/Incorrect flag. The comparison is made while Class is still numeric;
# Class is converted to a factor only afterwards, as the last step.
df_testing <- testing_data %>%
  mutate(
    Class = testing_label$Species,
    Prediction = predictions,
    Accuracy = ifelse(Class == Prediction, "Correct", "Incorrect"),
    Class = as.factor(Class)
  )

We create a scatter plot of the data to visualize correct and incorrect points

# Plot the testing rows: point shape encodes the true class, colour marks
# whether the KNN prediction was correct (green) or incorrect (black).
ggplot(df_testing, aes(x = Area, y = Aspect_Ration)) +
  geom_point(aes(shape = Class, colour = Accuracy)) +
  scale_color_manual(values = c("Correct" = "green", "Incorrect" = "black")) +
  labs(
    title = "Scatter Plot of Pumpkin Seed Testing Data using KNN",
    x = "Area",
    y = "Aspect Ratio"
  ) +
  theme(plot.title = element_text(hjust = 0.5))